package top.houry.limit;

import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

public class LeakyBucketRateLimiter {
    private final long capacity; // 桶的容量
    private final long rate;     // 令牌填充的速率（QPS）
    private final ReentrantLock lock = new ReentrantLock();
    private final Condition tokensAvailable = lock.newCondition();
    private final AtomicLong availableTokens = new AtomicLong(0);
    private long lastRefillTimestamp;

    public LeakyBucketRateLimiter(long rate, long capacity) {
        this.rate = rate;
        this.capacity = capacity;
        this.availableTokens.set(capacity);
        this.lastRefillTimestamp = System.currentTimeMillis();
    }

    /**
     * 在这个方法中，我们首先填充令牌，如果桶中有令牌，则消耗一个令牌并处理请求。如果桶中没有令牌，则线程会等待，直到被 refillTokens 方法唤醒。
     * @throws InterruptedException
     */
    public void acquire() throws InterruptedException {
        lock.lock();
        try {
            while (true) {
                refillTokens();
                if (availableTokens.get() > 0) {
                    availableTokens.decrementAndGet();
                    return;
                } else {
                    // 等待直到令牌变得可用
                    tokensAvailable.await();
                }
            }
        } finally {
            lock.unlock();
        }
    }

    /**
     * 在这个方法中，我们计算了自上次填充以来应该添加到桶中的令牌数，并更新了桶中的令牌数量和最后填充时间戳。如果添加了令牌，我们使用 signalAll 方法来唤醒所有等待的线程。
     */
    private void refillTokens() {
        long now = System.currentTimeMillis();
        long elapsedTime = now - lastRefillTimestamp;
        long tokensToAdd = (elapsedTime / 1000) * rate;
        if (tokensToAdd > 0) {
            long newTokens = Math.min(capacity, availableTokens.get() + tokensToAdd);
            availableTokens.set(newTokens);
            lastRefillTimestamp = now;
            // 唤醒所有等待的线程，因为令牌已经变得可用
            lock.lock();
            try {
                tokensAvailable.signalAll();
            } finally {
                lock.unlock();
            }
        }
    }

    public static void main(String[] args) throws InterruptedException {
        LeakyBucketRateLimiter rateLimiter = new LeakyBucketRateLimiter(100, 100);

        // 模拟多个线程中的请求
        for (int i = 0; i < 150; i++) {
            new Thread(() -> {
                try {
                    rateLimiter.acquire();
                    System.out.println("Request " + Thread.currentThread().getId() + " passed.");
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }).start();
//            Thread.sleep(10); // 模拟请求间隔
        }
    }
}